# -*- coding: utf-8 -*-
#
# generate_gif.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST.  If not, see <http://www.gnu.org/licenses/>.

r"""Script to visualize a simulated Pong game.
----------------------------------------------------------------
All simulations store data about both networks and the game in .pkl files.
This script reads these files and generates image snapshots at different
times during the simulation. These are subsequently aggregated into a GIF.

:Authors: J Gille, T Wunderlich, Electronic Vision(s)
"""

import gzip
import os
import pickle
import sys
from copy import copy
from glob import glob

import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
from pong import GameOfPong as Pong

# Shape of the grid on which all subplots of one output image are positioned
gridsize = (12, 16)

# Player colors: RGB arrays for image data, hex strings for text
left_color = np.array([204, 0, 153])  # purple
left_color_hex = "#cc0099"
right_color = np.array([255, 128, 0])  # orange
right_color_hex = "#ff8000"
white = np.array([255, 255, 255])

# Playing field: original size inside the simulation, the upscaling factor,
# and the resulting field size in px
GRID_SCALE = 24
GAME_GRID = np.array((Pong.x_grid, Pong.y_grid))
GAME_GRID_SCALED = GRID_SCALE * GAME_GRID

# Dimensions of game objects in px
BALL_RAD = 6
PADDLE_WID = 18
PADDLE_LEN = int(GAME_GRID_SCALED[1] * 0.1)

# The drawn field gets an extra margin on its left and right side
FIELD_PADDING = 2 * PADDLE_WID
FIELD_SIZE = GAME_GRID_SCALED + np.array([2 * FIELD_PADDING, 0])

# By default, the GIF shows every DEFAULT_SPEED-th simulation step.
DEFAULT_SPEED = 4


def scale_coordinates(coordinates: np.ndarray) -> np.ndarray:
    """Scale (x, y) coordinates from simulation scale to pixel scale in the
    output image.

    The x coordinates are additionally shifted by FIELD_PADDING to account for
    the margin left of the playing field. The input array is not modified.

    Args:
        coordinates (numpy.ndarray): 2D array with one (x, y) coordinate pair
        per row, in simulation units.

    Returns:
        numpy.ndarray: integer array of the same shape holding the coordinates
        in px (fractional parts are truncated).
    """
    # Work on a float copy so the caller's array is left untouched and no
    # intermediate value is truncated by an integer input dtype.
    scaled = np.array(coordinates, dtype=float)
    scaled[:, 0] = scaled[:, 0] * GAME_GRID_SCALED[0] / Pong.x_length + FIELD_PADDING
    scaled[:, 1] = scaled[:, 1] * GAME_GRID_SCALED[1] / Pong.y_length
    return scaled.astype(int)


def grayscale_to_heatmap(in_image, min_val, max_val, base_color):
    """Transforms a grayscale image to an RGB heat map. Heatmap will color small
    values in base_color and high values in white.

    Args:
        in_image (numpy.array): 2D numpy.array to be transformed.
        min_val (float): smallest value across the entire image - colored in
        base_color in the output.
        max_val (float): largest value across the entire image - colored
        white in the output.
        base_color (numpy.array): numpy.array of shape (3,) representing the
        base color of the heatmap in RGB.
    Returns:
        numpy.array: uint8 array of shape in_image.shape + (3,) holding the
        RGB value of every input pixel. If min_val equals max_val, every
        pixel is colored in base_color.
    """
    x_len, y_len = in_image.shape
    base_color = np.asarray(base_color)

    span = max_val - min_val
    # Edge case for uniform weight matrix: nothing to interpolate. Fill a
    # uint8 image directly so the dtype matches the regular code path.
    if span == 0:
        uniform = np.empty((x_len, y_len, 3), dtype=np.uint8)
        uniform[...] = base_color
        return uniform

    # Normalize all pixels to [0, 1]; clipping keeps values outside of
    # [min_val, max_val] from wrapping around in the uint8 conversion.
    color_scaled = np.clip((in_image - min_val) / span, 0.0, 1.0)
    # Interpolate linearly between base_color (0) and the module-level
    # color `white` (1); the new trailing axis broadcasts over RGB.
    heatmap = base_color + (white - base_color) * color_scaled[..., np.newaxis]

    return heatmap.astype(np.uint8)


if __name__ == "__main__":
    # Script entry point: read the simulation output folder given on the
    # command line, render one PNG per displayed simulation step into a
    # temporary folder and collect the PNGs into a GIF.
    keep_temps = False
    out_file = "pong_sim.gif"

    if len(sys.argv) != 2:
        print(
            "This programm takes exactly one argument - the location of the "
            "output folder generated by the simulation."
        )
        sys.exit(1)
    input_folder = sys.argv[1]

    # Never overwrite an existing GIF.
    if os.path.exists(out_file):
        print(f"<{out_file}> already exists, aborting!")
        sys.exit(1)

    temp_dir = "temp"

    # The temporary folder is removed at the end (unless keep_temps is set),
    # so refuse to reuse a folder that was not created by this run.
    if os.path.exists(temp_dir):
        print(f"Output folder <{temp_dir}> already exists, aborting!")
        sys.exit(1)
    os.mkdir(temp_dir)

    # NOTE: unpickling can execute arbitrary code. Only run this script on
    # simulation output from a trusted source.
    print(f"Reading simulation data from {input_folder}...")
    with open(os.path.join(input_folder, "gamestate.pkl"), "rb") as f:
        game_data = pickle.load(f)

    ball_positions = scale_coordinates(np.array(game_data["ball_pos"]))
    l_paddle_positions = scale_coordinates(np.array(game_data["left_paddle"]))
    # Move left paddle outwards for symmetry
    l_paddle_positions[:, 0] -= PADDLE_WID
    r_paddle_positions = scale_coordinates(np.array(game_data["right_paddle"]))

    # One (left score, right score) pair per simulation step
    score = np.array(game_data["score"]).astype(int)

    with gzip.open(os.path.join(input_folder, "data_left.pkl.gz"), "r") as f:
        data = pickle.load(f)
        rewards_left = data["rewards"]
        weights_left = data["weights"]
        name_left = data["network_type"]

    with gzip.open(os.path.join(input_folder, "data_right.pkl.gz"), "r") as f:
        data = pickle.load(f)
        rewards_right = data["rewards"]
        weights_right = data["weights"]
        name_right = data["network_type"]

    # Extract lowest and highest weights for both players to scale the heatmaps.
    min_r, max_r = np.min(weights_right), np.max(weights_right)
    min_l, max_l = np.min(weights_left), np.max(weights_left)

    # Average rewards at every iteration over all neurons
    rewards_left = [np.mean(x) for x in rewards_left]
    rewards_right = [np.mean(x) for x in rewards_right]

    print(f"Setup complete, generating images to '{temp_dir}'...")
    n_iterations = score.shape[0]
    i = 0
    output_speed = DEFAULT_SPEED

    # Loop-invariant figure settings: size of one pixel in inches and the
    # default font size used for all subplots.
    px = 1 / plt.rcParams["figure.dpi"]
    plt.rcParams.update({"font.size": 6})

    while i < n_iterations:
        fig, ax = plt.subplots(figsize=(400 * px, 300 * px))
        ax.set_axis_off()
        # Set up the grid containing all components of the output image
        title = plt.subplot2grid(gridsize, (0, 0), 1, 16)
        l_info = plt.subplot2grid(gridsize, (1, 0), 7, 2)
        r_info = plt.subplot2grid(gridsize, (1, 14), 7, 2)
        field = plt.subplot2grid(gridsize, (1, 2), 7, 12)
        l_hm = plt.subplot2grid(gridsize, (8, 0), 4, 4)
        reward_plot = plt.subplot2grid(gridsize, (8, 6), 4, 6)
        r_hm = plt.subplot2grid(gridsize, (8, 12), 4, 4)

        # Only the reward plot keeps its axes visible.
        for panel in [title, l_info, r_info, field, l_hm, r_hm]:
            panel.axis("off")

        # Create an empty array for the playing field.
        playing_field = np.zeros((FIELD_SIZE[0], FIELD_SIZE[1], 3), dtype=np.uint8)

        # Draw the ball in white. The slice bounds are clamped at 0: a
        # negative bound would be counted from the far end of the field and
        # yield an empty slice, making the ball vanish close to the edge.
        x, y = ball_positions[i]
        x_start, x_stop = max(x - BALL_RAD, 0), max(x + BALL_RAD, 0)
        y_start, y_stop = max(y - BALL_RAD, 0), max(y + BALL_RAD, 0)
        playing_field[x_start:x_stop, y_start:y_stop] = white
        for (x, y), color in zip([l_paddle_positions[i], r_paddle_positions[i]], [left_color, right_color]):
            # Clip y coordinate of the paddle so it does not exceed the screen
            y = max(PADDLE_LEN, y)
            y = min(FIELD_SIZE[1] - PADDLE_LEN, y)
            playing_field[x : x + PADDLE_WID, y - PADDLE_LEN : y + PADDLE_LEN] = color

        # playing_field is indexed as [x, y]; imshow expects [row, column]
        field.imshow(np.transpose(playing_field, [1, 0, 2]))

        # Left player heatmap
        heatmap_l = grayscale_to_heatmap(weights_left[i], min_l, max_l, left_color)
        l_hm.imshow(heatmap_l)
        l_hm.set_xlabel("output")
        l_hm.set_ylabel("input")
        l_hm.set_title("weights", y=-0.3)

        # Right player heatmap
        heatmap_r = grayscale_to_heatmap(weights_right[i], min_r, max_r, right_color)
        r_hm.imshow(heatmap_r)
        r_hm.set_xlabel("output")
        r_hm.set_ylabel("input")
        r_hm.set_title("weights", y=-0.3)

        # Mean rewards of both players up to the current step
        reward_plot.plot([0, i], [-1, -1])
        reward_plot.plot(rewards_right[: i + 1], color=right_color / 255)
        reward_plot.plot(rewards_left[: i + 1], color=left_color / 255)

        # Change x_ticks and x_min for the first few plots
        if i < 1600:
            x_min = 0
            reward_plot.set_xticks(np.arange(0, n_iterations, 250))
        else:
            x_min = i - 1600
            reward_plot.set_xticks(np.arange(0, n_iterations, 500))

        reward_plot.set_ylabel("mean reward")
        reward_plot.set_yticks([0, 0.5, 1])
        reward_plot.set_ylim(0, 1.0)
        reward_plot.set_xlim(x_min, i + 10)

        title.text(0.4, 0.75, name_left, ha="right", fontsize=15, c=left_color_hex)
        title.text(0.5, 0.75, "VS", ha="center", fontsize=17)
        title.text(0.6, 0.75, name_right, ha="left", fontsize=15, c=right_color_hex)

        l_score, r_score = score[i]

        l_info.text(0, 0.9, "run:", fontsize=14)
        l_info.text(0, 0.75, str(i), fontsize=14)
        l_info.text(1, 0.5, l_score, ha="right", va="center", fontsize=26, c=left_color_hex)

        r_info.text(0, 0.9, "speed:", fontsize=14)
        r_info.text(0, 0.75, str(output_speed) + "x", fontsize=14)
        r_info.text(0, 0.5, r_score, ha="left", va="center", fontsize=26, c=right_color_hex)

        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.35, hspace=0.35)
        # Zero-padded step numbers keep the files in order when sorted by name
        plt.savefig(os.path.join(temp_dir, f"img_{i:06d}.png"))

        # Change the speed of the video to show performance before and after
        # training at DEFAULT_SPEED and fast-forward most of the training
        if 75 <= i < 100 or n_iterations - 400 <= i < n_iterations - 350:
            output_speed = 10
        elif 100 <= i < n_iterations - 350:
            output_speed = 50
        else:
            output_speed = DEFAULT_SPEED

        i += output_speed
        # Release the figure so figures do not pile up across iterations
        plt.close(fig)

    print("Image creation complete, collecting them into a GIF...")

    filenames = sorted(glob(os.path.join(temp_dir, "*.png")))

    with imageio.get_writer(out_file, mode="I", duration=150) as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
    print(f"GIF created under: {out_file}")

    if not keep_temps:
        print("Deleting temporary image files...")
        for in_file in filenames:
            os.unlink(in_file)
        os.rmdir(temp_dir)

    print("Done.")
